Topological Cluster Statistics (TCS)¶


Notebook 11: localized sensitivity (effect of cluster defining threshold)¶


This notebook contains scripts that evaluate the localized (voxel-wise) sensitivity improvements of TCS compared to the cluster-based statistic at different cluster defining thresholds (supplementary analyses).


Packages and basic functions¶


Loading required packages

In [1]:
import os
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
from tqdm.notebook import tqdm

Basic functions

In [2]:
def ensure_dir(file_name):
    """Create the parent directory of ``file_name`` (if missing) and return the path unchanged."""
    parent = os.path.dirname(file_name)
    os.makedirs(parent, exist_ok=True)
    return file_name


def write_np(np_obj, file_path):
    """Serialize ``np_obj`` to ``file_path`` in NumPy .npy binary format."""
    outfile = open(file_path, 'wb')
    try:
        np.save(outfile, np_obj)
    finally:
        outfile.close()


def load_np(file_path):
    """Read a NumPy .npy file (as written by ``write_np``) and return its contents."""
    infile = open(file_path, 'rb')
    try:
        return np.load(infile)
    finally:
        infile.close()

Plot settings (LaTeX is used for higher-quality text rendering)

In [3]:
# Apply seaborn's default theme with the dark-grid background style.
sns.set()
sns.set_style("darkgrid")

# Render figures inline at retina (2x) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
In [4]:
# Render all figure text with LaTeX; mathtools + sfmath give sans-serif math.
plt.rcParams['text.usetex'] = True
plt.rcParams['text.latex.preamble'] = r'\usepackage{mathtools} \usepackage{sfmath}'

# Enlarge tick and axis-label fonts for readability in multi-panel figures.
plt.rcParams['xtick.labelsize'] = 20
plt.rcParams['ytick.labelsize'] = 20
plt.rcParams['axes.labelsize'] = 24

# High-resolution figure output.
plt.rcParams['figure.dpi'] = 500

Loading the ground truth¶


The ground truth stored in notebook 2 is loaded here.

In [5]:
# Map each HCP task to the cope number of its selected contrast.
tasks = dict(
    EMOTION='3',     # faces - shapes
    GAMBLING='6',    # reward - punish
    RELATIONAL='4',  # rel - match
    SOCIAL='6',      # tom - random
    WM='20',         # face - avg
)
In [6]:
# Load the precomputed ground-truth effect-size (Cohen's d) map for each task
# (stored by notebook 2).
ground_truth_effect = {}
# Base directory where files are stored at
# NOTE(review): hardcoded absolute path — adjust for your environment.
base_dir='/data/netapp01/work/sina/structural_clustering/PALM_revision_1'

for task in tqdm(tasks, desc="Tasks loop", leave=True):
    ground_truth_effect[task] = load_np(
        '{}/ground_truth/cohen_d_{}_cope{}.dscalar.npy'.format(base_dir, task, tasks[task]),
    )
Tasks loop:   0%|          | 0/5 [00:00<?, ?it/s]

Loading PALM results¶


PALM results stored in notebook 1 are loaded here.

In [7]:
%%time

# Number of random repetitions
repetitions = 500
# Different sample sizes tested
sample_sizes = [10, 20, 40, 80, 160, 320]
# Different cluster defining thresholds
cdts = [3.3, 2.8, 2.6, 2.0, 1.6]
# Number of brainordinates in a cifti file
Nv = 91282
# Base directory where files are stored at
# NOTE(review): hardcoded absolute path — adjust for your environment.
base_dir='/data/netapp01/work/sina/structural_clustering/PALM_revision_1'

# Store loaded results in nested python dictionaries keyed as
# loaded_maps[map_type][task][f'CDT={cdt}'].
loaded_maps = {}
loaded_maps['uncorrected_tstat'] = {}
loaded_maps['spatial_cluster_corrected_tstat'] = {}
loaded_maps['topological_cluster_corrected_tstat'] = {}

# This supplementary analysis fixes the sample size at 40 and sweeps over all
# CDTs (the main analyses use only CDT=3.3, i.e. p=0.001).
# cdt = 3.3
sample_size = 40
for task in tqdm(tasks, desc="Tasks loop", leave=True):
    loaded_maps['uncorrected_tstat'][task] = {}
    loaded_maps['spatial_cluster_corrected_tstat'][task] = {}
    loaded_maps['topological_cluster_corrected_tstat'][task] = {}
    for cdt in tqdm(cdts, desc="CDT loop", leave=False):
        loaded_maps['uncorrected_tstat'][task][f'CDT={cdt}'] = load_np(
            f'{base_dir}/summary/uncorrected_tstat_{task}_{sample_size}_samples_{cdt}_CDT.npy',
        )
        # NOTE(review): ensure_dir() creates the parent directory as a side
        # effect; it looks unnecessary here since these files are only read.
        loaded_maps['spatial_cluster_corrected_tstat'][task][f'CDT={cdt}'] = load_np(
            ensure_dir(f'{base_dir}/summary/spatial_cluster_corrected_tstat_{task}_{sample_size}_samples_{cdt}_CDT.npy'),
        )
        loaded_maps['topological_cluster_corrected_tstat'][task][f'CDT={cdt}'] = load_np(
            ensure_dir(f'{base_dir}/summary/topological_cluster_corrected_tstat_{task}_{sample_size}_samples_{cdt}_CDT.npy'),
        )
Tasks loop:   0%|          | 0/5 [00:00<?, ?it/s]
CDT loop:   0%|          | 0/5 [00:00<?, ?it/s]
CDT loop:   0%|          | 0/5 [00:00<?, ?it/s]
CDT loop:   0%|          | 0/5 [00:00<?, ?it/s]
CDT loop:   0%|          | 0/5 [00:00<?, ?it/s]
CDT loop:   0%|          | 0/5 [00:00<?, ?it/s]
CPU times: user 201 ms, sys: 31.1 s, total: 31.3 s
Wall time: 5min 13s

Localized sensitivity analyses¶

In [8]:
import scipy.stats as stats
import matplotlib.gridspec as gridspec
from scipy.interpolate import CubicSpline
from scipy.interpolate import UnivariateSpline
from statsmodels.stats.power import TTestPower
from matplotlib.patches import Patch
from matplotlib.image import NonUniformImage
import matplotlib.colors
from matplotlib.ticker import AutoMinorLocator

%config InlineBackend.figure_format = 'retina'

plt.rc('figure', dpi=300)

# Power-analysis helper; NOTE(review): `analysis` is not used in this cell —
# presumably a leftover from the main (per-sample-size) notebook.
analysis = TTestPower()

fig = plt.figure(figsize=(30, 6*len(cdts)))

# outer grid: one row per CDT, one column per task
gs = fig.add_gridspec(len(cdts), 5, wspace=0.2, hspace=0.2)

# Fixed sample size used for this supplementary analysis.
sample_size = 40

# One color per CDT row, used for heatmaps and marginal histograms.
sample_colors = np.array(sns.color_palette("rainbow", len(cdts)))

# Significance threshold on the -log10(p) scale (p < 0.05).
logp_threshold = -np.log10(0.05)

# One panel per (CDT row, task column): voxel-wise sensitivity difference
# (TCS minus spatial cluster statistic) as a function of ground-truth effect
# size, with marginal histograms on top (effect size) and right (difference).
for ci, task in enumerate(tasks):
    for ri, cdt in enumerate(cdts):
        # 6x6 inner grid: bottom-left 5x5 is the main panel; top row and
        # right column hold the marginal histograms.
        inner_grid = gridspec.GridSpecFromSubplotSpec(6,6, subplot_spec=gs[ri, ci], wspace=0.2, hspace=0.2)
        ax = fig.add_subplot(inner_grid[1:, :-1])

        # x values: ground-truth effect size (Cohen's d) per brainordinate.
        scatterx = ground_truth_effect[task]

        # Drop repetitions containing NaNs before averaging across repetitions.
        t_stats = loaded_maps['uncorrected_tstat'][task][f'CDT={cdt}']
        t_stats = t_stats[~np.isnan(t_stats).any(axis=1)]

        # TCS detection rate: fraction of repetitions where a brainordinate is
        # significant with the matching sign, masked by the ground-truth sign.
        topological_cluster_logps = loaded_maps['topological_cluster_corrected_tstat'][task][f'CDT={cdt}']
        topological_cluster_logps = topological_cluster_logps[~np.isnan(topological_cluster_logps).any(axis=1)]
        topological_positive_effects = np.multiply(np.mean((topological_cluster_logps>logp_threshold) & (t_stats>0), 0), (ground_truth_effect[task]>0))
        topological_negative_effects = np.multiply(np.mean((topological_cluster_logps>logp_threshold) & (t_stats<0), 0), (ground_truth_effect[task]<0))

        # Same detection rate for the spatial cluster-based statistic.
        spatial_cluster_logps = loaded_maps['spatial_cluster_corrected_tstat'][task][f'CDT={cdt}']
        spatial_cluster_logps = spatial_cluster_logps[~np.isnan(spatial_cluster_logps).any(axis=1)]
        spatial_positive_effects = np.multiply(np.mean((spatial_cluster_logps>logp_threshold) & (t_stats>0), 0), (ground_truth_effect[task]>0))
        spatial_negative_effects = np.multiply(np.mean((spatial_cluster_logps>logp_threshold) & (t_stats<0), 0), (ground_truth_effect[task]<0))

        # y values: localized sensitivity difference (TCS - spatial cluster).
        topological_scattery = (topological_positive_effects + topological_negative_effects)
        spatial_scattery = (spatial_positive_effects + spatial_negative_effects)
        scattery = topological_scattery - spatial_scattery

        xlim = (-1.5,1.5)
        ylim = (-0.5,0.5)

        # 2D density of (effect size, difference); log(1 + count) compresses
        # the dynamic range for display.
        heatmap, xedges, yedges = np.histogram2d(scatterx, scattery, bins=(81, 81), range=[[-1.5,1.5],[-0.5,0.5]])
        extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

        ax.imshow(np.log(1 + heatmap.T), extent=extent, origin='lower', aspect='auto',
                  cmap=matplotlib.colors.LinearSegmentedColormap.from_list(
                      'my_gradient',
                      (
                          (0.0, (0.9, 0.9, 0.95,)),
                          (1.0, sample_colors[ri]),
                      )
                  )
                 )

        # Dashed guides delimiting the |difference| > 0.1 and |d| > 0.2 regions.
        ax.axhline(y=0.1, xmin=0, xmax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)
        ax.axhline(y=-0.1, xmin=0, xmax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)
        ax.axvline(x=0.2, ymin=0, ymax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)
        ax.axvline(x=-0.2, ymin=0, ymax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)

        # Counts of brainordinates falling in each of the four outer quadrants.
        ax.text(xlim[1]-0.3, ylim[1]-0.1, '\\textbf{{ {} }}'.format(np.sum((scattery>0.1)&(scatterx>0.2))), fontsize=16, ha='center', va='center')
        ax.text(xlim[0]+0.3, ylim[1]-0.1, '\\textbf{{ {} }}'.format(np.sum((scattery>0.1)&(scatterx<-0.2))), fontsize=16, ha='center', va='center')
        ax.text(xlim[1]-0.3, ylim[0]+0.1, '\\textbf{{ {} }}'.format(np.sum((scattery<-0.1)&(scatterx>0.2))), fontsize=16, ha='center', va='center')
        ax.text(xlim[0]+0.3, ylim[0]+0.1, '\\textbf{{ {} }}'.format(np.sum((scattery<-0.1)&(scatterx<-0.2))), fontsize=16, ha='center', va='center')

        # Cubic-spline fit of the binned mean difference vs. effect size;
        # each estimate pools two adjacent bins, which smooths the curve.
        bins = np.linspace(max(xlim[0], scatterx.min()), min(xlim[1], scatterx.max()), 31)
        digitized = np.digitize(scatterx, bins)
        x_means = [scatterx[(digitized == i) | (digitized == i + 1)].mean() for i in range(1, len(bins) - 1)]
        y_means = [scattery[(digitized == i) | (digitized == i + 1)].mean() for i in range(1, len(bins) - 1)]
        y_sems = [stats.sem(scattery[(digitized == i) | (digitized == i + 1)]) for i in range(1, len(bins) - 1)]

        cs = CubicSpline(x_means, y_means, bc_type='natural', extrapolate=False)
        cs_sem = CubicSpline(x_means, y_sems, bc_type='natural', extrapolate=False)

        sample_x = np.linspace(scatterx.min(),scatterx.max(),200)
        sample_y = cs(sample_x)
        sample_y_sem = cs_sem(sample_x)

        sns.lineplot(
            x=sample_x,
            y=sample_y,
            style=True,
            dashes=[(1,1)],
            color=np.append(sample_colors[ri]/2, 1),
            legend=False,
            linewidth=4,
            ax=ax,  # draw explicitly on the main panel, not the implicit current axes
        )

        ax.set_ylim(ylim)
        ax.set_xlim(xlim)

        # Label the x axis only on the bottom row.
        # FIX: the original tested `ri == len(sample_sizes) - 1` (a leftover
        # from the per-sample-size notebook); rows here iterate over `cdts`
        # (5 rows vs. 6 sample sizes), so the label was never drawn.
        xlabel = ''
        if ri == len(cdts) - 1:
            xlabel = 'Effect size ($d$)'
        ax.set_xlabel(xlabel, fontsize=40)

        # Label the y axis (with the row's CDT) only on the first column.
        ylabel = ''
        if ci == 0:
            ylabel = '\\textbf{{ CDT={} }}'.format(cdt)
        ax.set_ylabel(ylabel, fontsize=40)

        # Styling: seaborn-like background, no spines, gray ticks.
        ax.set_facecolor(np.array([234,234,242])/255)
        ax.grid(color=(0.99,0.99,0.99,), linewidth=0.1)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.tick_params(axis='both', colors=(0.5,0.5,0.5), labelcolor=(0,0,0), direction='out')

        # Top marginal: distribution of ground-truth effect sizes.
        axu = fig.add_subplot(inner_grid[0, :-1], sharex=ax)
        plt.setp(axu.get_xticklabels(), visible=False)
        axu.grid(color=(0.99,0.99,0.99,), linewidth=0.1)
        axu.set_facecolor((0.9, 0.9, 0.95,))

        sns.histplot(x=scatterx, ax=axu, stat='probability', bins=81, binrange=(-1.5,1.5), color=sample_colors[ri], element="step")
        axu.set_ylabel('', fontsize=40)
        axu.set_xlim(xlim)
        axu.set_ylim([0,0.2])

        axu.spines['top'].set_visible(False)
        axu.spines['right'].set_visible(False)
        axu.spines['bottom'].set_visible(False)
        axu.spines['left'].set_visible(False)

        # Right marginal: distribution of sensitivity differences (log x scale).
        axr = fig.add_subplot(inner_grid[1:, -1], sharey=ax)
        plt.setp(axr.get_yticklabels(), visible=False)
        axr.grid(color=(0.99,0.99,0.99,), linewidth=0.1)
        axr.set_facecolor((0.9, 0.9, 0.95,))

        sns.histplot(y=scattery, ax=axr, stat='probability', bins=81, binrange=(-0.5,0.5), color=sample_colors[ri], element="step")
        axr.set_xlabel('', fontsize=40)
        axr.set_xscale('log')
        axr.set_ylim(ylim)

        axr.axhline(y=0.1, xmin=0, xmax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)
        axr.axhline(y=-0.1, xmin=0, xmax=1, dashes=(3,3), color=(0.5,0.5,0.5,1), linewidth=2,)

        # Counts of brainordinates with |difference| above 0.1 (either sign).
        axr.text(0.01, ylim[1]-0.1, '\\textbf{{ {} }}'.format(np.sum((scattery>0.1))), fontsize=16, ha='center', va='center')
        axr.text(0.01, ylim[0]+0.1, '\\textbf{{ {} }}'.format(np.sum((scattery<-0.1))), fontsize=16, ha='center', va='center')
        axr.set_xticks([0.001, 0.1])
        axr.set_xticklabels([0.001, 0.1], fontsize=16, rotation=55)
        axr.tick_params(axis="x", direction="inout", bottom=True, length=10, width=2)

        axr.spines['top'].set_visible(False)
        axr.spines['right'].set_visible(False)
        axr.spines['bottom'].set_visible(False)
        axr.spines['left'].set_visible(False)

plt.show()